import math
import time
from dataclasses import dataclass
from functools import partial
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import Attention, SpatialNorm
from diffusers.models.resnet import ResnetBlock2D
from diffusers.models.unets.unet_2d_blocks import (
    AutoencoderTinyBlock,
    get_down_block,
    get_up_block,
    UNetMidBlock2D,
)
from diffusers.utils import BaseOutput, is_torch_version
from diffusers.utils.torch_utils import randn_tensor

# Optional memory-efficient-attention hooks. `_InnerAttention` calls
# `xops.memory_efficient_attention(...)` when `efficient_attn` is set, so `xops`
# presumably stands for the xformers ops module. Both names are left as None
# here; NOTE(review): confirm where (if anywhere) they are assigned, since
# efficient attention cannot run while `xops` is None.
xops = None
AttentionOp = None


def get_model_parallel_size() -> int:
    """
    Return the model-parallel world size, or 1 when model parallelism is not active.

    The probe is best-effort: 1 is returned when ``torch.distributed`` is
    unavailable or uninitialized, when the model-parallel group is not
    initialized, or when the check itself raises (for example because
    ``fs_init`` is not bound in this module). Only ordinary exceptions are
    absorbed; ``KeyboardInterrupt`` / ``SystemExit`` propagate.
    """
    try:
        if (
            torch.distributed.is_available()
            and torch.distributed.is_initialized()
            and fs_init.model_parallel_is_initialized()
        ):
            return fs_init.get_model_parallel_world_size()
    except Exception:
        # A missing or broken model-parallel backend means "no model parallelism".
        return 1
    return 1


@dataclass
class InitArgs:
    """Configuration for the weight initializers built by ``get_init_fn``."""

    # True -> truncated normal, False -> uniform with the same std
    use_gaussian: bool = True
    # depth-based rescaling mode: None, "current" or "global"
    use_depth: Optional[str] = None
    # explicit std; mutually exclusive with coeff_std
    fixed_std: Optional[float] = None
    # multiplier applied to the default 1/sqrt(dim) std
    coeff_std: Optional[float] = None
    # only meaningful together with use_depth
    depth_last: bool = False
    # when True, initializers leave tensors untouched
    no_init: bool = False
    # factor to scale learned positional embedding init by
    pos_init_scalar: Optional[float] = None

    def __post_init__(self):
        allowed_depth_modes = (None, "current", "global")
        assert self.use_depth in allowed_depth_modes
        # At most one way of overriding the default std.
        assert self.coeff_std is None or self.fixed_std is None
        # depth_last requires a depth mode.
        assert self.use_depth is not None or not self.depth_last


def get_init_fn(
    args: InitArgs, input_dim: int, zero_out: bool = False
) -> Callable[[torch.Tensor], torch.Tensor]:
    """
    Build an in-place weight initializer from ``args``.

    Args:
        args: initialization settings.
        input_dim: fan-in used for the default std ``1 / sqrt(input_dim)``.
        zero_out: if True, return an initializer that fills the tensor with 0.

    Returns:
        A callable applied to a weight tensor. With ``args.no_init`` it is the
        identity (the tensor is left untouched). Otherwise it is a
        ``torch.nn.init`` function with ``args``-dependent keyword arguments.

    Note:
        ``args.use_depth`` / ``args.depth_last`` are not applied here; this
        function receives no depth information.
    """
    if args.no_init:
        return lambda x: x

    if zero_out:
        return partial(torch.nn.init.constant_, val=0.0)

    if args.fixed_std is not None:
        # InitArgs documents fixed_std as replacing the 1/sqrt(dim) default;
        # its __post_init__ forbids combining it with coeff_std.
        std = args.fixed_std
    else:
        std = 1 / math.sqrt(input_dim)
        if args.coeff_std is not None:
            std = args.coeff_std * std

    # gaussian vs uniform
    if args.use_gaussian:
        # Truncate at +/- 3 std.
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    bound = math.sqrt(3) * std  # U(-b, b) has std b / sqrt(3) == std
    return partial(torch.nn.init.uniform_, a=-bound, b=bound)


class _InnerAttention(torch.nn.Module):
    """
    Inner attention that is checkpointed if selective activation checkpointing
    is enabled.

    Splits already-projected q/k/v tensors into the heads owned by this
    model-parallel rank and computes softmax attention. It uses plain PyTorch
    when ``efficient_attn`` is None, and ``xops.memory_efficient_attention``
    otherwise.
    """

    def __init__(
        self,
        head_dim: int,
        n_heads: int,
        efficient_attn: Optional[str],
    ):
        super().__init__()
        self.head_dim = head_dim
        self.n_heads = n_heads
        self.efficient_attn = efficient_attn

        # efficient attention: resolve the op passed to xops in forward()
        if self.efficient_attn is not None:
            self.attn_op = get_attn_op(self.efficient_attn)

        # Heads are split evenly across model-parallel ranks.
        self.model_parallel_size = get_model_parallel_size()
        assert self.n_heads % self.model_parallel_size == 0, (
            f"n_heads ({self.n_heads}) must be divisible by the model parallel "
            f"size ({self.model_parallel_size})"
        )
        self.n_local_heads = self.n_heads // self.model_parallel_size

    def forward(
        self,
        xq: torch.Tensor,
        xk: torch.Tensor,
        xv: torch.Tensor,
    ):
        """
        Args:
            xq: (bs, slen, n_local_heads * head_dim) query projections.
            xk: (bs, klen, n_local_heads * head_dim) key projections.
            xv: (bs, klen, n_local_heads * head_dim) value projections.

        Returns:
            (bs, slen, n_local_heads * head_dim) attention output.

        Raises:
            RuntimeError: if efficient attention is requested but ``xops`` is None.
        """
        bs, slen, _ = xq.shape
        _, klen, _ = xk.shape

        xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
        xk = xk.view(bs, klen, self.n_local_heads, self.head_dim)
        xv = xv.view(bs, klen, self.n_local_heads, self.head_dim)

        if self.efficient_attn is None:
            # NOTE: maybe to consider https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/118
            xq = xq.transpose(1, 2)  # (bs, n_local_heads, slen, head_dim)
            xk = xk.transpose(1, 2)  # (bs, n_local_heads, klen, head_dim)
            xv = xv.transpose(1, 2)  # (bs, n_local_heads, klen, head_dim)
            # (bs, n_local_heads, slen, klen)
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            # Softmax in float32 for stability, then back to the activation dtype.
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, xv)  # (bs, n_local_heads, slen, head_dim)
            output = output.transpose(1, 2).contiguous()  # (bs, slen, n_local_heads, head_dim)
        else:
            if xops is None:
                raise RuntimeError(
                    "efficient_attn is set but `xops` (memory-efficient attention ops) "
                    "is not available"
                )
            # xops consumes the (bs, seq, heads, head_dim) layout directly.
            output = xops.memory_efficient_attention(xq, xk, xv, op=self.attn_op)

        return output.view(bs, slen, -1)


class NewAttention(torch.nn.Module):
    """
    Pre-normalized (self or cross) attention with a residual connection.

    The attention core lives in an ``_InnerAttention`` submodule so it can be
    checkpointed selectively. q/k/v use ``ColumnParallelLinear`` projections
    (no gather), and the output uses a ``RowParallelLinear`` that consumes the
    parallel input.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        dropout: float,
        efficient_attn: Optional[str],
        init_args: InitArgs,
        cross_attention_dim: int = None,
        norm_num_groups: int = None,
        norm_eps: float = None,
        use_layer_norm: bool = False,
    ):
        super().__init__()

        self.use_layer_norm = use_layer_norm
        self.norm = (
            FusedLayerNorm(dim, eps=1e-5, elementwise_affine=True)
            if use_layer_norm
            else nn.GroupNorm(
                num_channels=dim, num_groups=norm_num_groups, eps=norm_eps, affine=True
            )
        )

        inner_dim = n_heads * head_dim
        # Keys/values come from the cross-attention context when one is configured.
        context_dim = cross_attention_dim if cross_attention_dim is not None else dim
        in_proj_init = get_init_fn(init_args, dim)

        def column_projection(in_features):
            return ColumnParallelLinear(
                in_features,
                inner_dim,
                bias=False,
                gather_output=False,
                init_method=in_proj_init,
            )

        self.wq = column_projection(dim)
        self.wk = column_projection(context_dim)
        self.wv = column_projection(context_dim)
        self.inner_attention = _InnerAttention(
            head_dim=head_dim,
            n_heads=n_heads,
            efficient_attn=efficient_attn,
        )
        self.wo = RowParallelLinear(
            inner_dim,
            dim,
            bias=False,
            input_is_parallel=True,
            init_method=get_init_fn(init_args, dim),
        )

        self.memory_eff_dropout = nn.Dropout(dropout)  ###??fixed MemoryEfficientDropout

    def forward(
        self,
        x: torch.Tensor,
        temb: torch.Tensor = None,
    ):
        """
        Args:
            x: either (batch, channel, height, width), which is flattened to
                (batch, height * width, channel) internally, or an already
                flattened (batch, seq_len, channel) tensor.
            temb: optional key/value context. When None, keys/values come from
                the normalized ``x``; otherwise ``temb`` is projected as-is
                (it is not normalized here).

        Returns:
            A contiguous tensor with the same layout as ``x``: ``x`` plus the
            dropout-regularized attention output.
        """
        is_spatial = x.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = x.shape
            x = x.reshape(batch_size, channel, height * width).transpose(1, 2)

        residual = x
        if self.use_layer_norm:
            hidden = self.norm(x)
        else:
            # GroupNorm expects channels on dim 1.
            hidden = self.norm(x.transpose(1, 2)).transpose(1, 2)

        context = hidden if temb is None else temb
        attended = self.inner_attention(
            self.wq(hidden), self.wk(context), self.wv(context)
        )
        output = residual + self.memory_eff_dropout(self.wo(attended))

        if is_spatial:
            output = output.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        return output.contiguous()


def replace_layers(
    model,
    head_dim: int = 512,
    n_heads: int = 1,
    efficient_attn: Optional[str] = "cutlass",
    norm_num_groups: int = 32,
    norm_eps: float = 1e-6,
    verbose: bool = True,
):
    """
    Recursively swap every diffusers ``Attention`` submodule of ``model`` for ``NewAttention``.

    Only modules whose exact type is ``Attention`` are replaced; subclasses are
    left alone. Replacement happens in place via ``setattr`` on the parent.
    The new layers are built with ``InitArgs(no_init=True)`` and no weights are
    copied from the replaced module.

    Args:
        model: root module to rewrite in place.
        head_dim: per-head dimension of the new attention layers.
        n_heads: number of heads of the new attention layers.
        efficient_attn: forwarded to ``NewAttention``; selects the efficient attention op.
        norm_num_groups: GroupNorm groups of the new layers' pre-norm.
        norm_eps: GroupNorm epsilon of the new layers' pre-norm.
        verbose: print each replaced module.
    """
    init_args = InitArgs(no_init=True)
    for name, module in model.named_children():
        if type(module) is Attention:
            if verbose:
                print("replaced: ", module)
            new_attention = NewAttention(
                module.query_dim,
                head_dim,
                n_heads,
                module.dropout,
                efficient_attn=efficient_attn,
                init_args=init_args,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                use_layer_norm=False,
            )
            # Rebinding an existing child keeps the dict size unchanged, so
            # mutating while iterating named_children() is safe.
            setattr(model, name, new_attention)
        elif next(module.children(), None) is not None:
            # compound module, go inside it
            replace_layers(
                module,
                head_dim=head_dim,
                n_heads=n_heads,
                efficient_attn=efficient_attn,
                norm_num_groups=norm_num_groups,
                norm_eps=norm_eps,
                verbose=verbose,
            )


@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        commit_loss (`torch.FloatTensor`, *optional*, defaults to `None`):
            Optional auxiliary loss carried alongside `sample`. The name suggests a
            (vector-quantization) commitment loss — NOTE(review): confirm with the producer.
    """

    sample: torch.Tensor
    commit_loss: Optional[torch.FloatTensor] = None


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    The encoder supports per-sample compression rates. For a sample with rate `r`, the output of down
    block `starting_block_idx + r` is brought to `block_out_channels[-1]` channels. It is then passed
    through `pre_mid_block`, `mid_block` and the output head. Every down block except the final one
    downsamples, so larger rates produce spatially smaller latents.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
        num_layers (`int`, *optional*, defaults to 1):
            Number of layers of the `UNetMidBlock2D`.
        mid_block_attention_head_dim (`int`, *optional*, defaults to 1):
            The mid block's `attention_head_dim` is `block_out_channels[-1] // mid_block_attention_head_dim`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether the mid block contains attention.
        mid_block_processing (`str`, *optional*, defaults to `"pad"`):
            How an intermediate feature map is brought to `block_out_channels[-1]` channels. `"pad"` zero-pads
            the channel dimension. `"conv"` uses a learned `ResnetBlock2D` per down block.
        starting_block_idx (`int`, *optional*, defaults to 1):
            Index of the first down block whose output can serve as the latent (compression rate 0).
        compression_range (`int`, *optional*, defaults to 4):
            Number of compression rates (`0 .. compression_range - 1`) handled when a batch mixes rates.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        num_layers: int = 1,
        mid_block_attention_head_dim: int = 1,
        mid_block_add_attention=True,
        mid_block_processing="pad",
        starting_block_idx=1,
        compression_range=4,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.block_out_channels = block_out_channels
        self.mid_block_processing = mid_block_processing
        self.starting_block_idx = starting_block_idx
        self.compression_range = compression_range

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]

        # One ResnetBlock2D per down block (from starting_block_idx on) whose channels
        # differ from the last block's; only populated in "conv" mode.
        self.channel_matching = []
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            # The final block is applied separately, right before the mid block.
            if is_final_block:
                self.pre_mid_block = down_block
            else:
                self.down_blocks.append(down_block)

            if (
                self.mid_block_processing == "conv"
                and i >= self.starting_block_idx
                and output_channel != block_out_channels[-1]
            ):
                self.channel_matching.append(
                    ResnetBlock2D(
                        in_channels=output_channel,
                        out_channels=block_out_channels[-1],
                        temb_channels=None,
                        eps=1e-6,
                        groups=norm_num_groups,
                        time_embedding_norm="default",
                        non_linearity=act_fn,
                        output_scale_factor=1,
                    )
                )

        if len(self.channel_matching) > 0:
            self.channel_matching = nn.ModuleList(self.channel_matching)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            num_layers=num_layers,
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1] // mid_block_attention_head_dim,
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        # Optional per-rate output heads; disabled (None) here.
        self.conv_norm_out_list = None
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(
            block_out_channels[-1], conv_out_channels, 3, padding=1
        )
        self.conv_out_list = None
        self.conv_out_channels = conv_out_channels

        self.gradient_checkpointing = False

    def _match_channels(self, sample: torch.Tensor, cmatch_idx: int) -> torch.Tensor:
        """Bring `sample` to `block_out_channels[-1]` channels (zero-pad or learned ResnetBlock2D)."""
        if self.mid_block_processing == "pad":
            # F.pad pads from the last dim backwards: (W, W, H, H, C_front, C_back).
            return F.pad(
                sample,
                (0, 0, 0, 0, 0, self.block_out_channels[-1] - sample.shape[1]),
                "constant",
                0,
            )
        return self.channel_matching[cmatch_idx](sample, temb=None)

    def _forward_uniform_rate(
        self, sample: torch.Tensor, compression_rate: int
    ) -> Tuple[torch.Tensor, list]:
        """Encode a batch whose samples all share `compression_rate`, exiting early at its block."""
        spatial_dims = []
        target_idx = self.starting_block_idx + compression_rate
        cmatch_idx = 0
        for bidx, down_block in enumerate(self.down_blocks):
            sample = down_block(sample)
            if bidx < self.starting_block_idx:
                continue
            spatial_dims.append(sample.shape[-1])
            needs_matching = self.block_out_channels[-1] != sample.shape[1]
            if bidx == target_idx:
                # Stop at the selected block even when its channels already match,
                # so the latent resolution agrees with the mixed-rate path.
                if needs_matching:
                    sample = self._match_channels(sample, cmatch_idx)
                break
            if needs_matching:
                # Skip past the channel-matching module owned by this block.
                cmatch_idx += 1

        sample = self.pre_mid_block(sample)
        sample = self.mid_block(sample)
        if self.conv_norm_out_list is None:
            cno = self.conv_norm_out
            co = self.conv_out
        else:
            # NOTE(review): indexed with `rate - 1`, so rate 0 selects the last head.
            # The lists are None after __init__, so this branch is currently unused.
            cno = self.conv_norm_out_list[compression_rate - 1]
            co = self.conv_out_list[compression_rate - 1]
        sample = cno(sample)
        sample = self.conv_act(sample)
        sample = co(sample)
        return sample, spatial_dims

    def _forward_mixed_rates(
        self, sample: torch.Tensor, compression_rates: torch.Tensor
    ) -> Tuple[torch.Tensor, list]:
        """Encode a batch with differing rates; latents are zero-padded to the largest one."""
        spatial_dims = []
        all_samples = []
        cmatch_idx = 0
        for bidx, down_block in enumerate(self.down_blocks):
            sample = down_block(sample)
            if bidx < self.starting_block_idx:
                continue
            spatial_dims.append(sample.shape[-1])
            if self.block_out_channels[-1] != sample.shape[1]:
                all_samples.append(self._match_channels(sample, cmatch_idx))
                if self.mid_block_processing != "pad":
                    cmatch_idx += 1
            else:
                all_samples.append(sample)

        # The smallest rate in the batch has the largest spatial size.
        max_spatial_dim = spatial_dims[int(compression_rates.min().item())]
        out = torch.zeros(
            (sample.shape[0], self.conv_out_channels, max_spatial_dim, max_spatial_dim),
            dtype=sample.dtype,
            device=sample.device,
        )
        for c_rate in range(self.compression_range):
            selected = compression_rates == c_rate
            if not selected.any():
                continue
            latent = self.pre_mid_block(all_samples[c_rate][selected])
            latent = self.mid_block(latent)

            # post-process
            latent = self.conv_norm_out(latent)
            latent = self.conv_act(latent)
            latent = self.conv_out(latent)
            # Pad right/bottom up to the common spatial size.
            out[selected] = F.pad(
                latent,
                (
                    0,
                    max_spatial_dim - latent.shape[-1],
                    0,
                    max_spatial_dim - latent.shape[-2],
                ),
                "constant",
                0,
            )

        return out, spatial_dims

    def forward(
        self, sample: torch.Tensor, compression_rates: torch.Tensor
    ) -> Tuple[torch.Tensor, list]:
        r"""
        The forward method of the `Encoder` class.

        Args:
            sample: input batch fed to `conv_in`.
            compression_rates: per-sample integer rates; rate `r` selects down block
                `starting_block_idx + r`.

        Returns:
            A tuple `(latents, spatial_dims)`. `spatial_dims` lists the width (last dim) of each
            down-block output from `starting_block_idx` on. In the uniform-rate path it stops at the
            selected block. With mixed rates, latents are zero-padded to the largest spatial size
            in the batch.
        """
        sample = self.conv_in(sample)
        all_same_rate = compression_rates.min().item() == compression_rates.max().item()
        if all_same_rate:
            return self._forward_uniform_rate(sample, int(compression_rates.min().item()))
        return self._forward_mixed_rates(sample, compression_rates)


class Decoder(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Mirrors the `Encoder`'s variable compression. A latent with compression rate `r` enters the up-block
    stack at block `len(up_blocks) - r - starting_block_idx - 2`. Larger rates therefore start earlier and
    pass through more up blocks.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        mid_block_supervision (`bool`, *optional*, defaults to `True`):
            Whether to build one extra head per up block (GroupNorm -> activation -> 3x3 conv to 3 channels).
            When `label` is passed to `forward`, these heads add an MSE loss on intermediate features
            (uniform-rate path only).
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        num_layers (`int`, *optional*, defaults to 1):
            Number of layers of the `UNetMidBlock2D`.
        mid_block_attention_head_dim (`int`, *optional*, defaults to 1):
            The mid block's `attention_head_dim` is `block_out_channels[-1] // mid_block_attention_head_dim`.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether the mid block contains attention.
        mid_block_processing (`str`, *optional*, defaults to `"pad"`):
            How the mid-block features are adapted to the channel count of the entry up block. `"pad"` slices
            off extra channels. Any other value uses a learned `ResnetBlock2D` (built only for `"conv"`).
        starting_block_idx (`int`, *optional*, defaults to 1):
            Offset used when computing the entry up block for each compression rate.
        compression_range (`int`, *optional*, defaults to 4):
            Number of compression rates (`0 .. compression_range - 1`) handled when a batch mixes rates.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        mid_block_supervision: bool = True,
        act_fn: str = "silu",
        num_layers: int = 1,
        mid_block_attention_head_dim: int = 1,
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention: bool = True,
        mid_block_processing: str = "pad",
        starting_block_idx: int = 1,
        compression_range: int = 4,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.mid_block_processing = mid_block_processing
        self.starting_block_idx = starting_block_idx
        self.block_out_channels = block_out_channels
        self.compression_range = compression_range

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # NOTE(review): this list is built and then immediately replaced by None, so it is dead.
        # `[conv] * 3` would also repeat the SAME module three times. Constructing the Conv2d still
        # consumes the RNG, so deleting these lines would change later parameter initialization.
        self.conv_in_list = nn.ModuleList(
            [
                nn.Conv2d(
                    in_channels,
                    block_out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            ]
            * 3
        )
        self.conv_in_list = None

        self.up_blocks = nn.ModuleList([])

        # SpatialNorm conditions on `latent_embeds`, whose channel count is taken to be in_channels.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            num_layers=num_layers,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1] // mid_block_attention_head_dim,
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )
        # Debug output: parameter count of the mid block (printed on every construction).
        print("dec mid_block", sum(p.numel() for p in self.mid_block.parameters()))

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        self.reversed_block_out_channels = reversed_block_out_channels
        output_channel = reversed_block_out_channels[0]

        self.channel_matching = []
        # Intermediate-supervision heads. Entry 0 (for the widest features) is always created,
        # even when mid_block_supervision is False. One more is appended per up block below.
        self.mid_supervision = [
            nn.Sequential(
                nn.GroupNorm(
                    num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
                ),
                get_activation(act_fn),
                nn.Conv2d(
                    in_channels=output_channel, out_channels=3, kernel_size=3, padding=1
                ),
            )
        ]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

            prev_output_channel = output_channel

            # Learned channel reduction from the mid-block width to this block's width ("conv" mode).
            if (
                self.mid_block_processing == "conv"
                and i >= self.starting_block_idx
                and output_channel != reversed_block_out_channels[0]
            ):
                self.channel_matching.append(
                    ResnetBlock2D(
                        in_channels=reversed_block_out_channels[0],
                        out_channels=output_channel,
                        temb_channels=None,
                        eps=1e-6,
                        groups=norm_num_groups,
                        time_embedding_norm="default",
                        non_linearity=act_fn,
                        output_scale_factor=1,
                    )
                )  # nn.Conv2d(reversed_block_out_channels[0], output_channel, kernel_size=1, stride=1, padding=0))
            if mid_block_supervision:
                self.mid_supervision.append(
                    nn.Sequential(
                        nn.GroupNorm(
                            num_channels=output_channel,
                            num_groups=norm_num_groups,
                            eps=1e-6,
                        ),
                        get_activation(act_fn),
                        nn.Conv2d(
                            in_channels=output_channel,
                            out_channels=3,
                            kernel_size=3,
                            padding=1,
                        ),
                    )
                )
                # Zero-init only the appended heads' conv bias (entry 0 keeps the default init).
                nn.init.zeros_(self.mid_supervision[-1][-1].bias)

        if len(self.channel_matching) > 0:
            self.channel_matching = nn.ModuleList(self.channel_matching)
        # Always true (entry 0 always exists), so the heads are always registered as parameters.
        if len(self.mid_supervision) > 0:
            self.mid_supervision = nn.ModuleList(self.mid_supervision)
        self.mid_block_supervision = mid_block_supervision

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        compression_rates: torch.Tensor,
        spatial_dims: list,
        label: Optional[torch.Tensor] = None,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Decoder` class.

        Args:
            sample: latent batch. In the mixed-rate path it is cropped per rate to
                `spatial_dims[c_rate]`. Presumably it is padded to a common size, as `Encoder` returns it.
            compression_rates: per-sample compression rates (tensor).
            spatial_dims: list of widths (presumably the one returned by `Encoder.forward`). Only read
                when the rates differ.
            label: optional target image. It enables intermediate supervision in the uniform-rate path.
            latent_embeds: optional conditioning, forwarded to the mid/up blocks and to `conv_norm_out`.

        Returns:
            A tuple `(sample, mid_loss)` (despite the `torch.Tensor` annotation). `mid_loss` stays the
            int `0` unless intermediate supervision ran.
        """

        upscale_dtype = self.conv_out.weight.dtype
        all_same_rate = compression_rates.min().item() == compression_rates.max().item()

        if all_same_rate:

            cmatch_idx = 0
            # NOTE(review): assigned but unused in this branch.
            max_spatial_dim = sample.shape[-1]
            if self.conv_in_list is None:
                ci = self.conv_in
            else:
                # Dead while conv_in_list is None. Indexed with `rate - 1`, so rate 0 would pick the last entry.
                ci = self.conv_in_list[compression_rates.min().item() - 1]
            sample = ci(sample)
            sample = self.mid_block(sample, latent_embeds).to(upscale_dtype)
            # Per-sample entry block index (a tensor, since compression_rates is a tensor).
            start_block_idx = (
                len(self.up_blocks) - compression_rates - self.starting_block_idx - 2
            )
            pass_start_block = False
            already_optimized = []
            mid_loss = 0
            for bidx, up_block in enumerate(self.up_blocks):
                mask = bidx == start_block_idx
                if mask.sum() > 0 or pass_start_block:
                    # Intermediate supervision: predict a 3-channel image from the current features and
                    # compare it with the label average-pooled down to the same width. Each width is
                    # supervised once and the full-resolution width is skipped. The pooling kernel
                    # assumes the label width is a multiple of the feature width.
                    if (
                        label is not None
                        and self.mid_block_supervision
                        and sample.shape[-1] not in already_optimized
                        and sample.shape[-1] != label.shape[-1]
                    ):
                        already_optimized.append(sample.shape[-1])
                        target = F.avg_pool2d(
                            label, label.shape[-1] // sample.shape[-1]
                        )
                        mid_pred = self.mid_supervision[
                            bidx - self.starting_block_idx + 1
                        ](sample)
                        mid_loss += F.mse_loss(mid_pred, target, reduction="mean")
                if mask.sum() > 0:
                    # Entry block: adapt the mid-block features to the channel count expected here.
                    pass_start_block = True
                    if self.mid_block_processing == "pad":
                        # Keep the leading channels (presumably undoing the Encoder's zero-padding).
                        sample = sample[
                            :, : self.reversed_block_out_channels[max(0, bidx - 1)], ...
                        ]  #:spatial_dims[-1 - bidx], :spatial_dims[-1 - bidx]] + (1 - mask) * sample #(1 - mask) *
                    elif (
                        self.reversed_block_out_channels[max(0, bidx - 1)]
                        != sample.shape[1]
                    ):
                        sample = self.channel_matching[cmatch_idx](
                            sample, temb=None
                        )  # sample_in[..., :spatial_dims[-1 - bidx], :spatial_dims[-1 - bidx]]) + (1 - mask) * sample
                        cmatch_idx += 1
                    sample = up_block(sample, latent_embeds)
                elif pass_start_block:
                    sample = up_block(sample, latent_embeds)

            if latent_embeds is None:
                sample = self.conv_norm_out(sample)
            else:
                sample = self.conv_norm_out(sample, latent_embeds)
            sample = self.conv_act(sample)
            sample = self.conv_out(sample)

            return sample, mid_loss

        else:
            # Mixed rates: no intermediate supervision in this path (mid_loss stays 0).
            cmatch_idx = 0
            mid_loss = 0
            max_spatial_dim = sample.shape[-1]
            out = torch.zeros(
                (
                    sample.shape[0],
                    self.reversed_block_out_channels[0],
                    max_spatial_dim,
                    max_spatial_dim,
                ),
                dtype=sample.dtype,
                device=sample.device,
            )
            # Run conv_in + mid block separately per rate on each group's un-padded latent,
            # then zero-pad the results back into a shared buffer.
            for c_rate in range(self.compression_range):
                if (compression_rates == c_rate).sum() > 0:

                    selected_sample = sample[compression_rates == c_rate][
                        ..., : spatial_dims[c_rate], : spatial_dims[c_rate]
                    ]
                    selected_sample = self.conv_in(selected_sample)
                    selected_sample = self.mid_block(selected_sample, latent_embeds).to(
                        upscale_dtype
                    )
                    out[compression_rates == c_rate] = F.pad(
                        selected_sample,
                        (
                            0,
                            max_spatial_dim - selected_sample.shape[-1],
                            0,
                            max_spatial_dim - selected_sample.shape[-2],
                        ),
                        "constant",
                        0,
                    )

            # The whole batch runs through every up block. Each sample's features are
            # overwritten with its own mid-block output when its entry block is reached.
            sample = out[..., : spatial_dims[-1], : spatial_dims[-1]]
            start_block_idx = (
                len(self.up_blocks) - compression_rates - self.starting_block_idx - 2
            )
            for bidx, up_block in enumerate(self.up_blocks):

                mask = bidx == start_block_idx
                if mask.sum() > 0:
                    # Crop the mid-block buffer to the current resolution, adapt channels,
                    # and inject it for the samples that enter at this block.
                    selected_sample = out[..., : sample.shape[2], : sample.shape[3]]
                    if self.mid_block_processing == "pad":
                        selected_sample = selected_sample[:, : sample.shape[1], ...]
                    elif out.shape[1] != sample.shape[1]:
                        selected_sample = self.channel_matching[cmatch_idx](
                            selected_sample, temb=None
                        )
                        cmatch_idx += 1
                    sample[mask] = selected_sample[mask]
                sample = up_block(sample, latent_embeds)

            if latent_embeds is None:
                sample = self.conv_norm_out(sample)
            else:
                sample = self.conv_norm_out(sample, latent_embeds)
            sample = self.conv_act(sample)
            sample = self.conv_out(sample)

            return sample, mid_loss


class UpSample(nn.Module):
    r"""
    Upsampling layer of a variational autoencoder.

    Applies a ReLU and then a transposed convolution (kernel 4, stride 2,
    padding 1), which doubles the spatial height and width.

    Args:
        in_channels (`int`):
            Number of channels of the incoming feature map.
        out_channels (`int`):
            Number of channels produced by the transposed convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # (H - 1) * 2 - 2 * 1 + 4 == 2 * H, so this exactly doubles H and W.
        self.deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Return ``deconv(relu(x))``."""
        activated = torch.relu(x)
        return self.deconv(activated)


class MaskConditionEncoder(nn.Module):
    """
    Convolutional encoder for the masked image, used in AsymmetricAutoencoderKL.

    The first two convolutions keep the spatial size (kernel 3, stride 1); every
    following convolution halves it (kernel 4, stride 2). Channel widths start
    at ``out_ch``, double per halving of ``stride`` and are capped at
    ``res_ch``.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int = 192,
        res_ch: int = 768,
        stride: int = 16,
    ) -> None:
        super().__init__()

        # Build the (input_width, output_width) schedule by repeatedly halving
        # the stride. Note the input width is computed *before* clamping.
        schedule = []
        width = out_ch
        remaining = stride
        while remaining > 1:
            remaining //= 2
            pair_in = width * 2
            width = min(width, res_ch)
            if remaining == 1:
                pair_in = res_ch
            schedule.append((pair_in, width))
            width *= 2

        widths = [pair_out for _, pair_out in schedule]
        widths.append(schedule[-1][0])

        convs = []
        prev_width = in_ch
        for idx, cur_width in enumerate(widths):
            if idx < 2:
                # Resolution-preserving stem convolutions.
                conv = nn.Conv2d(
                    prev_width, cur_width, kernel_size=3, stride=1, padding=1
                )
            else:
                # Downsampling convolutions (halve H and W).
                conv = nn.Conv2d(
                    prev_width, cur_width, kernel_size=4, stride=2, padding=1
                )
            convs.append(conv)
            prev_width = cur_width

        self.layers = nn.Sequential(*convs)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        r"""Run every conv layer and return all pre-activation feature maps.

        The result is a dict keyed by ``str(tuple(feature.shape))``. Each value
        is the corresponding conv output before the ReLU. ``mask`` is accepted
        but not used.
        """
        features = {}
        for conv in self.layers:
            x = conv(x)
            features[str(tuple(x.shape))] = x
            x = torch.relu(x)
        return features


class MaskConditionDecoder(nn.Module):
    r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
    decoder with a conditioner on the mask and masked image.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        norm_type (`str`, *optional*, defaults to `"group"`):
            The normalization type to use. Can be either `"group"` or `"spatial"`.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.up_blocks = nn.ModuleList([])

        # Spatial norm is conditioned on the latent, so it needs its channel count.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)

        # condition encoder
        self.condition_encoder = MaskConditionEncoder(
            in_ch=out_channels,
            out_ch=block_out_channels[0],
            res_ch=block_out_channels[-1],
        )

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
            )
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def _maybe_checkpoint(self, module: Callable, *args):
        """Call ``module(*args)``, via activation checkpointing when training with it enabled."""
        if self.training and self.gradient_checkpointing:
            if is_torch_version(">=", "1.11.0"):
                return torch.utils.checkpoint.checkpoint(
                    module, *args, use_reentrant=False
                )
            return torch.utils.checkpoint.checkpoint(module, *args)
        return module(*args)

    def forward(
        self,
        z: torch.Tensor,
        image: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        latent_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Decode ``z``, optionally blending in features of the masked ``image``.

        Args:
            z: Latent input to the decoder.
            image: Original image. Used only when ``mask`` is also given.
            mask: Blend mask. Where ``mask`` is 1 the decoded features are kept;
                where it is 0 the condition-encoder features of
                ``(1 - mask) * image`` are used.
            latent_embeds: Optional conditioning passed to the mid/up blocks and,
                when not ``None``, to ``conv_norm_out``.

        Returns:
            The decoded tensor produced by ``conv_out``.
        """
        sample = self.conv_in(z)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # middle
        sample = self._maybe_checkpoint(self.mid_block, sample, latent_embeds)
        sample = sample.to(upscale_dtype)

        # condition encoder
        use_condition = image is not None and mask is not None
        if use_condition:
            masked_image = (1 - mask) * image
            # Dict of feature maps keyed by str(tuple(shape)); see MaskConditionEncoder.
            im_x = self._maybe_checkpoint(self.condition_encoder, masked_image, mask)

        # up
        for up_block in self.up_blocks:
            if use_condition:
                # Blend in the encoder feature map whose shape matches the current sample.
                sample_ = im_x[str(tuple(sample.shape))]
                mask_ = nn.functional.interpolate(
                    mask, size=sample.shape[-2:], mode="nearest"
                )
                sample = sample * mask_ + sample_ * (1 - mask_)
            sample = self._maybe_checkpoint(up_block, sample, latent_embeds)
        if use_condition:
            sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
    multiplications and allows for post-hoc remapping of indices.
    """

    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(
        self,
        n_e: int,
        vq_embed_dim: int,
        beta: float,
        remap=None,
        unknown_index: str = "random",
        sane_index_shape: bool = False,
        legacy: bool = True,
    ):
        """
        Args:
            n_e: Number of codebook entries.
            vq_embed_dim: Dimension of each codebook vector. Must equal the
                channel dimension of the input to ``forward``.
            beta: Commitment-loss weight.
            remap: Optional path to a ``.npy`` array of used codebook indices
                (loaded with ``np.load``).
            unknown_index: ``"random"``, ``"extra"`` or an integer. Used by
                ``remap_to_used`` for indices not present in ``used``.
            sane_index_shape: If True, ``forward`` returns indices shaped
                (batch, height, width).
            legacy: Selects which loss term ``beta`` multiplies (see NOTE above).
        """
        super().__init__()
        self.n_e = n_e
        self.vq_embed_dim = vq_embed_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.used: torch.Tensor
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(
                f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                f"Using {self.unknown_index} for unknown indices."
            )
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
        """Map full-codebook indices to their position in ``self.used``."""
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
                device=new.device
            )
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
        """Map positions in ``self.used`` back to full-codebook indices.

        The input tensor is not modified.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        if self.re_embed > self.used.shape[0]:  # extra token
            # Out-of-range (the "extra" token) maps to 0. Use an out-of-place op:
            # ``reshape`` may return a view, and an in-place write would
            # silently mutate the caller's index tensor.
            inds = torch.where(
                inds >= self.used.shape[0], torch.zeros_like(inds), inds
            )
        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
        return back.reshape(ishape)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Tuple]:
        """Quantize ``z`` to its nearest codebook vectors.

        Args:
            z: Tensor of shape (batch, channel, height, width). ``channel`` must
                equal ``vq_embed_dim``.

        Returns:
            ``(z_q, loss, (perplexity, min_encodings, min_encoding_indices))``,
            where ``z_q`` has the shape of ``z`` and ``perplexity`` and
            ``min_encodings`` are always ``None``.
        """
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.vq_embed_dim)

        # Nearest codebook entry by Euclidean distance.
        min_encoding_indices = torch.argmin(
            torch.cdist(z_flattened, self.embedding.weight), dim=1
        )

        z_q = self.embedding(min_encoding_indices).reshape(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
                (z_q - z.detach()) ** 2
            )
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
                (z_q - z.detach()) ** 2
            )

        # preserve gradients (straight-through estimator)
        z_q: torch.Tensor = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(
                z.shape[0], -1
            )  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3]
            )

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(
        self, indices: torch.LongTensor, shape: Optional[Tuple[int, ...]]
    ) -> torch.Tensor:
        """Look up codebook vectors for ``indices``.

        ``shape`` specifies (batch, height, width, channel). When given, the
        result is reshaped to it and permuted to (batch, channel, height, width).
        """
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q: torch.Tensor = self.embedding(indices)

        if shape is not None:
            z_q = z_q.reshape(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q


class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian parameterised by a tensor holding mean and log-variance.

    ``parameters`` is split in two along dim 1: the first half is the mean, the
    second half the log-variance (clamped to [-30, 20]). ``mask`` multiplies
    samples and the per-element KL terms.
    """

    def __init__(
        self, parameters: torch.Tensor, mask: torch.Tensor, deterministic: bool = False
    ):
        self.parameters = parameters
        self.mask = mask
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp for numerical stability of exp() below.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
            )

    def _zero(self) -> torch.Tensor:
        """Shape-(1,) zero on the parameters' device/dtype (deterministic-mode result)."""
        # A CPU tensor of shape (1,) cannot be combined with CUDA losses, so
        # create it alongside the parameters.
        return torch.zeros(
            1, device=self.parameters.device, dtype=self.parameters.dtype
        )

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw ``mean + std * eps`` and multiply it by ``mask``."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        x = x * self.mask

        return x

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """Masked KL divergence, summed over dims 1-3.

        Computed against a standard normal when ``other`` is None, otherwise
        against ``other``.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                (torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar) * self.mask,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            (
                torch.pow(self.mean - other.mean, 2) / other.var
                + self.var / other.var
                - 1.0
                - self.logvar
                + other.logvar
            )
            * self.mask,
            dim=[1, 2, 3],
        )

    def nll(
        self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)
    ) -> torch.Tensor:
        """Gaussian negative log-likelihood of ``sample``, summed over ``dims``."""
        # NOTE(review): unlike sample()/kl(), this term is not multiplied by
        # self.mask — confirm whether that is intended.
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the mean."""
        return self.mean


class EncoderTiny(nn.Module):
    r"""
    The `EncoderTiny` layer is a simpler version of the `Encoder` layer.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        act_fn: str,
    ):
        super().__init__()

        layers = []
        for i, num_block in enumerate(num_blocks):
            num_channels = block_out_channels[i]

            if i == 0:
                layers.append(
                    nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)
                )
            else:
                # Stride-2 downsampling conv. Its input is the previous stage's
                # width, which matters when block_out_channels are not uniform.
                layers.append(
                    nn.Conv2d(
                        block_out_channels[i - 1],
                        num_channels,
                        kernel_size=3,
                        padding=1,
                        stride=2,
                        bias=False,
                    )
                )

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

        layers.append(
            nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)
        )

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode ``x``, an image expected in [-1, 1]."""
        # scale image from [-1, 1] to [0, 1] to match TAESD convention.
        # Done before branching so the checkpointed training path sees the
        # same input range as the regular path.
        x = x.add(1).div(2)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.layers), x, use_reentrant=False
                )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.layers), x
                )

        else:
            x = self.layers(x)

        return x


class DecoderTiny(nn.Module):
    r"""
    The `DecoderTiny` layer is a simpler version of the `Decoder` layer.

    Args:
        in_channels (`int`):
            The number of input channels.
        out_channels (`int`):
            The number of output channels.
        num_blocks (`Tuple[int, ...]`):
            Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
            use.
        block_out_channels (`Tuple[int, ...]`):
            The number of output channels for each block.
        upsampling_scaling_factor (`int`):
            The scaling factor to use for upsampling.
        act_fn (`str`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        upsample_fn (`str`):
            The interpolation `mode` passed to `nn.Upsample`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: Tuple[int, ...],
        block_out_channels: Tuple[int, ...],
        upsampling_scaling_factor: int,
        act_fn: str,
        upsample_fn: str,
    ):
        super().__init__()

        layers = [
            nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
            get_activation(act_fn),
        ]

        for i, num_block in enumerate(num_blocks):
            is_final_block = i == (len(num_blocks) - 1)
            num_channels = block_out_channels[i]

            for _ in range(num_block):
                layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))

            if not is_final_block:
                layers.append(
                    nn.Upsample(
                        scale_factor=upsampling_scaling_factor, mode=upsample_fn
                    )
                )

            # Each non-final stage's conv feeds the next stage, so it must emit
            # that stage's width (identical to num_channels when widths are uniform).
            conv_out_channel = (
                block_out_channels[i + 1] if not is_final_block else out_channels
            )
            layers.append(
                nn.Conv2d(
                    num_channels,
                    conv_out_channel,
                    kernel_size=3,
                    padding=1,
                    bias=is_final_block,
                )
            )

        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Decode ``x`` and return the result mapped from [0, 1] to [-1, 1]."""
        # Soft clamp of the latent to (-3, 3).
        x = torch.tanh(x / 3) * 3

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.layers), x, use_reentrant=False
                )
            else:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.layers), x
                )

        else:
            x = self.layers(x)

        # scale image from [0, 1] to [-1, 1] to match diffusers convention
        return x.mul(2).sub(1)
